{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# K-Fold Cross-Validation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# importing data\n",
    "data_path= '../data/diamonds.csv'\n",
    "diamonds = pd.read_csv(data_path)\n",
    "diamonds = pd.concat([diamonds, pd.get_dummies(diamonds['cut'], prefix='cut', drop_first=True)],axis=1)\n",
    "diamonds = pd.concat([diamonds, pd.get_dummies(diamonds['color'], prefix='color', drop_first=True)],axis=1)\n",
    "diamonds = pd.concat([diamonds, pd.get_dummies(diamonds['clarity'], prefix='clarity', drop_first=True)],axis=1)\n",
    "diamonds.drop(['cut','color','clarity'], axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Diamonds dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing objects for modelling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import RobustScaler\n",
    "target_name = 'price'\n",
    "robust_scaler = RobustScaler()\n",
    "X = diamonds.drop('price', axis=1)\n",
    "X = robust_scaler.fit_transform(X)\n",
    "y = diamonds[target_name]\n",
    "# Notice that we are not doing train-test split\n",
    "#X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=55)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestRegressor\n",
    "RF = RandomForestRegressor(n_estimators=50, max_depth=16, random_state=123, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# this will work from sklearn version 0.19, if you get an error \n",
    "# make sure you upgrade: $conda upgrade scikit-learn\n",
    "from sklearn.model_selection import cross_validate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "scores = cross_validate(estimator=RF,X=X,y=y,\n",
    "                        scoring=['mean_squared_error','r2'],\n",
    "                        cv=10, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fit_time</th>\n",
       "      <th>score_time</th>\n",
       "      <th>test_mean_squared_error</th>\n",
       "      <th>test_r2</th>\n",
       "      <th>train_mean_squared_error</th>\n",
       "      <th>train_r2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.704191</td>\n",
       "      <td>0.720918</td>\n",
       "      <td>3.755390e+05</td>\n",
       "      <td>0.538764</td>\n",
       "      <td>148065.528065</td>\n",
       "      <td>0.991526</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.141356</td>\n",
       "      <td>0.988628</td>\n",
       "      <td>4.506041e+05</td>\n",
       "      <td>0.672636</td>\n",
       "      <td>150123.441197</td>\n",
       "      <td>0.991437</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.756991</td>\n",
       "      <td>1.060821</td>\n",
       "      <td>1.429308e+06</td>\n",
       "      <td>0.386105</td>\n",
       "      <td>118993.885068</td>\n",
       "      <td>0.993105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.542923</td>\n",
       "      <td>1.004674</td>\n",
       "      <td>2.386801e+06</td>\n",
       "      <td>0.569107</td>\n",
       "      <td>121708.194620</td>\n",
       "      <td>0.992298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.403554</td>\n",
       "      <td>1.176127</td>\n",
       "      <td>6.002576e+06</td>\n",
       "      <td>0.653763</td>\n",
       "      <td>84805.134870</td>\n",
       "      <td>0.990100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3.737440</td>\n",
       "      <td>0.910923</td>\n",
       "      <td>1.376623e+06</td>\n",
       "      <td>0.958366</td>\n",
       "      <td>134400.626049</td>\n",
       "      <td>0.990314</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>3.839710</td>\n",
       "      <td>4.791745</td>\n",
       "      <td>2.447721e+04</td>\n",
       "      <td>-0.314355</td>\n",
       "      <td>149193.566169</td>\n",
       "      <td>0.990960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>5.881141</td>\n",
       "      <td>0.306817</td>\n",
       "      <td>6.405753e+04</td>\n",
       "      <td>-0.214988</td>\n",
       "      <td>149713.173174</td>\n",
       "      <td>0.991024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5.870614</td>\n",
       "      <td>0.363968</td>\n",
       "      <td>1.156133e+05</td>\n",
       "      <td>0.304016</td>\n",
       "      <td>156899.220946</td>\n",
       "      <td>0.990759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>6.064633</td>\n",
       "      <td>0.298291</td>\n",
       "      <td>1.976350e+05</td>\n",
       "      <td>0.396521</td>\n",
       "      <td>154009.670050</td>\n",
       "      <td>0.991083</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   fit_time  score_time  test_mean_squared_error   test_r2  \\\n",
       "0  2.704191    0.720918             3.755390e+05  0.538764   \n",
       "1  3.141356    0.988628             4.506041e+05  0.672636   \n",
       "2  3.756991    1.060821             1.429308e+06  0.386105   \n",
       "3  3.542923    1.004674             2.386801e+06  0.569107   \n",
       "4  3.403554    1.176127             6.002576e+06  0.653763   \n",
       "5  3.737440    0.910923             1.376623e+06  0.958366   \n",
       "6  3.839710    4.791745             2.447721e+04 -0.314355   \n",
       "7  5.881141    0.306817             6.405753e+04 -0.214988   \n",
       "8  5.870614    0.363968             1.156133e+05  0.304016   \n",
       "9  6.064633    0.298291             1.976350e+05  0.396521   \n",
       "\n",
       "   train_mean_squared_error  train_r2  \n",
       "0             148065.528065  0.991526  \n",
       "1             150123.441197  0.991437  \n",
       "2             118993.885068  0.993105  \n",
       "3             121708.194620  0.992298  \n",
       "4              84805.134870  0.990100  \n",
       "5             134400.626049  0.990314  \n",
       "6             149193.566169  0.990960  \n",
       "7             149713.173174  0.991024  \n",
       "8             156899.220946  0.990759  \n",
       "9             154009.670050  0.991083  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores = pd.DataFrame(scores)\n",
    "scores['test_mean_squared_error'] = -1*scores['test_mean_squared_error']\n",
    "scores['train_mean_squared_error'] = -1*scores['train_mean_squared_error']\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean test MSE: 1242323\n",
      "Mean test R-squared: 0.39499334944982994\n"
     ]
    }
   ],
   "source": [
    "print(\"Mean test MSE:\", round(scores['test_mean_squared_error'].mean()))\n",
    "print(\"Mean test R-squared:\", scores['test_r2'].mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Credit card default dataset "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "default = pd.read_csv('../data/credit_card_default.csv', index_col=\"ID\")\n",
    "default.rename(columns=lambda x: x.lower(), inplace=True)\n",
    "default.rename(columns={'pay_0':'pay_1','default payment next month':'default'}, inplace=True)\n",
    "# Base values: female, other_education, not_married\n",
    "default['grad_school'] = (default['education'] == 1).astype('int')\n",
    "default['university'] = (default['education'] == 2).astype('int')\n",
    "default['high_school'] = (default['education'] == 3).astype('int')\n",
    "default.drop('education', axis=1, inplace=True)\n",
    "\n",
    "default['male'] = (default['sex']==1).astype('int')\n",
    "default.drop('sex', axis=1, inplace=True)\n",
    "\n",
    "default['married'] = (default['marriage'] == 1).astype('int')\n",
    "default.drop('marriage', axis=1, inplace=True)\n",
    "\n",
    "# For pay_n features if >0 then it means the customer was delayed on that month\n",
    "pay_features = ['pay_' + str(i) for i in range(1,7)]\n",
    "for p in pay_features:\n",
    "    default[p] = (default[p] > 0).astype(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing objects for modelling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "target_name = 'default'\n",
    "X_credit = default.drop('default', axis=1)\n",
    "feature_names = X_credit.columns\n",
    "robust_scaler = RobustScaler()\n",
    "X_credit = robust_scaler.fit_transform(X_credit)\n",
    "y_credit = default[target_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "RF_credit = RandomForestClassifier(n_estimators=35, max_depth=20, random_state=55, \n",
    "                                   max_features='sqrt', n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "scores_credit = cross_validate(estimator=RF_credit, X=X_credit, y=y_credit,\n",
    "                        scoring=['accuracy','precision','recall'],\n",
    "                        cv=10, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fit_time</th>\n",
       "      <th>score_time</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>test_precision</th>\n",
       "      <th>test_recall</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>train_precision</th>\n",
       "      <th>train_recall</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.865801</td>\n",
       "      <td>0.322090</td>\n",
       "      <td>0.795068</td>\n",
       "      <td>0.563636</td>\n",
       "      <td>0.326807</td>\n",
       "      <td>0.949554</td>\n",
       "      <td>0.996981</td>\n",
       "      <td>0.774113</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.981133</td>\n",
       "      <td>0.315338</td>\n",
       "      <td>0.801400</td>\n",
       "      <td>0.589005</td>\n",
       "      <td>0.338855</td>\n",
       "      <td>0.949924</td>\n",
       "      <td>0.995921</td>\n",
       "      <td>0.776792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.001128</td>\n",
       "      <td>0.315592</td>\n",
       "      <td>0.802732</td>\n",
       "      <td>0.591940</td>\n",
       "      <td>0.352410</td>\n",
       "      <td>0.949517</td>\n",
       "      <td>0.997841</td>\n",
       "      <td>0.773443</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.952572</td>\n",
       "      <td>0.317346</td>\n",
       "      <td>0.796734</td>\n",
       "      <td>0.569231</td>\n",
       "      <td>0.334337</td>\n",
       "      <td>0.953109</td>\n",
       "      <td>0.996623</td>\n",
       "      <td>0.790690</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.763571</td>\n",
       "      <td>0.317416</td>\n",
       "      <td>0.807333</td>\n",
       "      <td>0.614973</td>\n",
       "      <td>0.346386</td>\n",
       "      <td>0.950111</td>\n",
       "      <td>0.995288</td>\n",
       "      <td>0.778131</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.728374</td>\n",
       "      <td>0.315622</td>\n",
       "      <td>0.808333</td>\n",
       "      <td>0.594080</td>\n",
       "      <td>0.423193</td>\n",
       "      <td>0.950926</td>\n",
       "      <td>0.995521</td>\n",
       "      <td>0.781648</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.734450</td>\n",
       "      <td>0.316079</td>\n",
       "      <td>0.827943</td>\n",
       "      <td>0.686076</td>\n",
       "      <td>0.408748</td>\n",
       "      <td>0.947743</td>\n",
       "      <td>0.995439</td>\n",
       "      <td>0.767286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1.205706</td>\n",
       "      <td>0.588914</td>\n",
       "      <td>0.826275</td>\n",
       "      <td>0.707602</td>\n",
       "      <td>0.365008</td>\n",
       "      <td>0.950743</td>\n",
       "      <td>0.997643</td>\n",
       "      <td>0.779508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1.238292</td>\n",
       "      <td>0.590923</td>\n",
       "      <td>0.817272</td>\n",
       "      <td>0.670623</td>\n",
       "      <td>0.340875</td>\n",
       "      <td>0.947261</td>\n",
       "      <td>0.996073</td>\n",
       "      <td>0.764105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1.768040</td>\n",
       "      <td>0.314352</td>\n",
       "      <td>0.810937</td>\n",
       "      <td>0.631148</td>\n",
       "      <td>0.348416</td>\n",
       "      <td>0.947261</td>\n",
       "      <td>0.997158</td>\n",
       "      <td>0.763770</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   fit_time  score_time  test_accuracy  test_precision  test_recall  \\\n",
       "0  0.865801    0.322090       0.795068        0.563636     0.326807   \n",
       "1  0.981133    0.315338       0.801400        0.589005     0.338855   \n",
       "2  1.001128    0.315592       0.802732        0.591940     0.352410   \n",
       "3  0.952572    0.317346       0.796734        0.569231     0.334337   \n",
       "4  0.763571    0.317416       0.807333        0.614973     0.346386   \n",
       "5  0.728374    0.315622       0.808333        0.594080     0.423193   \n",
       "6  0.734450    0.316079       0.827943        0.686076     0.408748   \n",
       "7  1.205706    0.588914       0.826275        0.707602     0.365008   \n",
       "8  1.238292    0.590923       0.817272        0.670623     0.340875   \n",
       "9  1.768040    0.314352       0.810937        0.631148     0.348416   \n",
       "\n",
       "   train_accuracy  train_precision  train_recall  \n",
       "0        0.949554         0.996981      0.774113  \n",
       "1        0.949924         0.995921      0.776792  \n",
       "2        0.949517         0.997841      0.773443  \n",
       "3        0.953109         0.996623      0.790690  \n",
       "4        0.950111         0.995288      0.778131  \n",
       "5        0.950926         0.995521      0.781648  \n",
       "6        0.947743         0.995439      0.767286  \n",
       "7        0.950743         0.997643      0.779508  \n",
       "8        0.947261         0.996073      0.764105  \n",
       "9        0.947261         0.997158      0.763770  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores_credit = pd.DataFrame(scores_credit)\n",
    "scores_credit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "test_accuracy     0.809403\n",
       "test_precision    0.621831\n",
       "test_recall       0.358503\n",
       "dtype: float64"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores_credit[['test_accuracy','test_precision','test_recall']].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "test_accuracy     0.011415\n",
       "test_precision    0.050435\n",
       "test_recall       0.032185\n",
       "dtype: float64"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores_credit[['test_accuracy','test_precision','test_recall']].std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
